{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a1a0cf63",
   "metadata": {},
   "source": [
    "# Test score fusion and calibration for Speaker Verification (SV)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abdc0067",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "import sys\n",
    "os.chdir('../..')\n",
    "sys.path.insert(1, os.path.join(sys.path[0], '../..'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fccb6bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "import matplotlib\n",
    "matplotlib.rc('pdf', fonttype=42)\n",
    "\n",
    "from notebooks.notebooks_utils import (\n",
    "    load_models,\n",
    "    evaluate_models,\n",
    "    create_metrics_df\n",
    ")\n",
    "\n",
    "from notebooks.evaluation.sv_visualization import scores_distribution\n",
    "from notebooks.evaluation.ScoreCalibration import ScoreCalibration\n",
    "\n",
    "from sslsv.evaluations.CosineSVEvaluation import CosineSVEvaluation, CosineSVEvaluationTaskConfig\n",
    "from sslsv.evaluations.CosineSVEvaluation import SpeakerVerificationEvaluation, SpeakerVerificationEvaluationTaskConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7c468fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "models = load_models(\n",
    "    [\n",
    "        'models/ssl/voxceleb2/simclr/simclr_enc-ECAPATDNN-1024_proj-none_t-0.03/config.yml',\n",
    "        'models/ssl/voxceleb2/dino/dino_enc-ECAPATDNN-1024_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04/config.yml'\n",
    "    ],\n",
    "    override_names={\n",
    "        'models/ssl/voxceleb2/simclr/simclr_enc-ECAPATDNN-1024_proj-none_t-0.03'   : 'simclr',\n",
    "        'models/ssl/voxceleb2/dino/dino_enc-ECAPATDNN-1024_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04' : 'dino'\n",
    "    },\n",
    "    checkpoint_name=\"model_avg.pt\"\n",
    ")\n",
    "\n",
    "# \"SimCLR\":       \"models/ssl/voxceleb2/simclr/simclr_enc-ECAPATDNN-1024_proj-none_t-0.03/\",\n",
    "# \"MoCo\":         \"models/ssl/voxceleb2/moco/moco_enc-ECAPATDNN-1024_proj-none_Q-32768_t-0.03_m-0.999/\",\n",
    "# \"SwAV\":         \"models/ssl/voxceleb2/swav/swav_enc-ECAPATDNN-1024_proj-2048-BN-R-2048-BN-R-512_K-6000_t-0.1/\",\n",
    "# \"VICReg\":       \"models/ssl/voxceleb2/vicreg/vicreg_enc-ECAPATDNN-1024_proj-2048-BN-R-2048-BN-R-512_inv-1.0_var-1.0_cov-0.1/\",\n",
    "# \"DINO\":         \"models/ssl/voxceleb2/dino/dino_enc-ECAPATDNN-1024_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04/\",\n",
    "# \"Supervised\":   \"models/ssl/voxceleb2/supervised/supervised_enc-ECAPATDNN-1024_loss-AAM_s-30_m-0.2/\","
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "144dd8c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "evals = evaluate_models(models, CosineSVEvaluation, CosineSVEvaluationTaskConfig(), return_evals=True)\n",
    "create_metrics_df(models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "025d7fb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "class FusedAndCalibratedSVEvaluation(SpeakerVerificationEvaluation):\n",
    "    \n",
    "    def __init__(self, evaluations, *args, **kwargs):\n",
    "        super().__init__(*args, **kwargs)\n",
    "        \n",
    "        self.evaluations = evaluations\n",
    "        self.sc = ScoreCalibration(evaluations)\n",
    "\n",
    "    def _prepare_evaluation(self):\n",
    "        self.sc.train()\n",
    "    \n",
    "    def _get_sv_score(self, a, b):\n",
    "        scores = [evaluation._get_sv_score(a, b) for evaluation in self.evaluations]\n",
    "        score = self.sc.predict(torch.tensor(scores).unsqueeze(0))\n",
    "        return score.detach().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7793749",
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluation = FusedAndCalibratedSVEvaluation(\n",
    "    evaluations=evals,\n",
    "    model=None,\n",
    "    config=evals[0].config,\n",
    "    task_config=SpeakerVerificationEvaluationTaskConfig()\n",
    ")\n",
    "\n",
    "models['final'] = {\n",
    "    'metrics': evaluation.evaluate(),\n",
    "    'scores': evaluation.scores,\n",
    "    'targets': evaluation.targets\n",
    "}\n",
    "\n",
    "create_metrics_df(models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "156450d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "scores_distribution(models, use_angle=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env_plot",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
